Union Find
1. Union Find
(1) 概要
当初は要素がバラバラ
「2つの要素が同じグループになる」という操作を何度か行う
「要素Aと要素Bが同じグループか?」という問いにおよそO(N)で答える.
典型問題:グラフの2点を枝で連結し続け,ある2点が到達可能か?
あとはAtCoderとかAOJの提出一覧を漁っていくのも面白い
(2) 実装例
0-indexedなので,問題が1-indexedのときはindexを-1して使って下さい.
code:python
""" シンプルUnion Find """
from collections import defaultdict
class UnionFind:
""" UnionFind (0-index) """
def __init__(self, n: int):
"""
UnionFindの初期化
Parameters
----------
n : int
頂点数
"""
self.n = n
self.cycle = 0
def find(self, x: int) -> int:
"""
頂点xの根となる頂点番号を取得する
Parameters
----------
x : int
頂点番号
"""
return x
else:
self.parentsx = self.find(self.parentsx) def union(self, x: int, y: int) -> None:
"""
頂点xが属する集合と頂点yが属する集合を連結する
Parameters
----------
x : int
頂点番号
y : int
頂点番号
"""
x = self.find(x)
y = self.find(y)
# 閉路となる辺は繋げない
if x == y:
self.cycle += 1
return
if self.parentsx > self.parentsy: x, y = y, x
self.parentsx += self.parentsy def size(self, x: int) -> int:
"""
頂点xが属する集合の要素数を取得する
Parameters
----------
x : int
頂点番号
"""
def same(self, x: int, y: int) -> bool:
"""
頂点xと頂点yが同じ集合に属しているか判定する
Parameters
----------
x : int
頂点番号
y : int
頂点番号
"""
return self.find(x) == self.find(y)
def cycle_count(self) -> int:
"""
閉路の数を取得する
Notes
-----
閉路が出来るようなunionを行っていた際、
内部では実際にunionはせずに、閉路の数を加算している
"""
return self.cycle
def members(self, x: int) -> list:
"""
頂点xと同じ集合に属する要素を全て取得する
Parameters
----------
x : int
頂点番号
"""
root = self.find(x)
def roots(self) -> list:
""" 根である要素を全て取得する """
def group_count(self) -> int:
""" 全ての集合の数を取得する """
return len(self.roots())
def all_group_members(self) -> defaultdict(list):
"""
全ての集合とそれに属する要素を取得する
Returns
-------
all_group_members : defaultdict(list)
key : 根要素の頂点番号, value : keyを根とする全ての頂点番号を含むリスト
"""
group_members = defaultdict(list)
for member in range(self.n):
return group_members
def debug_print(self) -> None:
"""
デバッグ用に現在の状態をprint出力する
Notes
-----
"""
return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())
code:python
### 使用例 ###
uf = UnionFind(6) # 0~5までの6個の点を作成
uf.debug_print() # 0:0, 1:1, 2:2, 3:3, 4:4, 5:5 print("\n")
uf.union(0, 2) # 点0が属するグループと、点2が属するグループを結合
uf.debug_print() # 0:0,2, 1:1, 3:3, 4:4, 5:5 print("\n")
uf.union(1, 3) # 点1が属するグループと、点3が属するグループを結合
uf.debug_print() # 0:0,2, 1:1,3, 4:4, 5:5 print("\n")
uf.union(4, 5) # 点4が属するグループと、点5が属するグループを結合
print("\n")
uf.union(1, 4) # 点1が属するグループと、点4が属するグループを結合
print("\n")
print(uf.find(0)) # 点0が属するグループは? --> 0
print(uf.find(5)) # 点5が属するグループは? --> 1
print(uf.size(0)) # 点0が属するグループのメンバー数は? --> 2
print(uf.size(5)) # 点5が属するグループのメンバー数は? --> 4
print(uf.same(0, 2)) # 点0と点2は同じグループ? --> True
print(uf.same(0, 5)) # 点0と点5は同じグループ? --> False
print(uf.members(0)) # 点0と同じグループの点は? --> 0, 2 print(uf.members(5)) # 点5と同じグループの点は? --> 1, 3, 4, 5 print(uf.roots()) # 各グループの根となる点は? --> 0, 1 print(uf.group_count()) # グループ数は? --> 2
print(uf.all_group_members()) # グループ状態は? --> (defaultdict) {0: 0, 2, 1: 1, 3, 4, 5} uf.union(0, 2) # 点0と点2が属するグループが同じなので、閉路の数が加算される
print(uf.cycle_count()) # 閉路の数は? --> 1
1-index,点番号0~N-1以外のケースへの対応版
code: python
from collections import defaultdict
class UnionFind():
""" UnionFind (1-index) """
def __init__(self, n: int):
"""
UnionFindの初期化
Parameters
----------
n : int
頂点数
"""
self.n = n + 1
self.cycle = 0 # 閉路の数
def find(self, x: int) -> int:
"""
頂点xの根となる頂点番号を取得する
Parameters
----------
x : int
頂点番号
"""
return x
else:
self.parentsx = self.find(self.parentsx) def union(self, x: int, y: int) -> None:
"""
頂点xが属する集合と頂点yが属する集合を連結する
"""
x = self.find(x)
y = self.find(y)
# 閉路となる辺は繋げない
if x == y:
self.cycle += 1
return
# 集合を繋げるときに要素数の大きい方に小さいほうを繋げるようにする
# 2つの集合の要素数を比較して入れ替える
if self.parentsx > self.parentsy: x, y = y, x
self.parentsx += self.parentsy def size(self, x: int) -> int:
"""
頂点xが属する集合の要素数を取得する
Parameters
----------
x : int
頂点番号
"""
def same(self, x: int, y: int) -> bool:
"""
頂点xと頂点yが同じ集合に属しているか判定する
Parameters
----------
x : int
頂点番号
y : int
頂点番号
"""
return self.find(x) == self.find(y)
def cycle_count(self) -> int:
"""
閉路の数を取得する
Notes
-----
閉路が出来るようなunionを行っていた際、
内部では実際にunionはせずに、閉路の数を加算している
"""
return self.cycle
def members(self, x: int) -> list:
"""
頂点xと同じ集合に属する要素を全て取得する
Parameters
----------
x : int
頂点番号
"""
root = self.find(x)
def roots(self) -> list:
"""
根である要素を全て取得する
"""
def group_count(self) -> int:
"""
全ての集合の数を取得する
"""
return len(self.roots())
def all_group_members(self) -> defaultdict(list):
"""
全ての集合とそれに属する要素を取得する
Returns
-------
all_group_members : defaultdict(list)
key : 根要素の頂点番号, value : keyを根とする全ての頂点番号を含むリスト
"""
group_members = defaultdict(list)
for member in range(self.n):
if self.find(member) != 0:
return group_members
def debug_print(self) -> None:
"""
デバッグ用に現在の状態をprint出力する
Notes
-----
"""
print('\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items()))
code:python
### 使用例 ###
uf = UnionFind(6) # 0~6までの6個の点を作成。1-indexなので0は使わない
uf.debug_print() # 0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6 print("\n")
uf.union(1, 3) # 点1が属するグループに, 点3が属するグループを結合
uf.debug_print() # 0:0, 1:1, 3, 2:2, 4:4, 5:5, 6:6 print("\n")
uf.union(2, 4)
print("\n")
uf.union(5, 6)
print("\n")
uf.union(2, 5)
print("\n")
print(uf.find(1)) # 点1が属するグループは? --> 1
print(uf.find(6)) # 点6が属するグループは? --> 2
print(uf.size(1)) # 点1が属するグループのメンバー数は? --> 2
print(uf.size(6)) # 点5が属するグループのメンバー数は? --> 4
print(uf.same(1, 3)) # 点1と点3は同じグループ? --> True
print(uf.same(1, 6)) # 点1と点6は同じグループ? --> False
print(uf.members(1)) # 点1と同じグループの点は? --> 1, 3 print(uf.members(5)) # 点5と同じグループの点は? --> 1, 3, 4, 5 print(uf.roots()) # 各グループの根となる点は? --> 1, 2 print(uf.group_count()) # グループ数は? --> 2
print(uf.all_group_members()) # グループ状態は? --> (defaultdict) {1: 1, 3, 1: 1, 3, 4, 5} uf.union(0, 2) # 点0と点2が属するグループが同じなので、閉路の数が加算される
print(uf.cycle_count()) # 閉路の数は? --> 1
2. Union Find利用例
(1) グラフの連結性判定
無向グラフにおいて,2点間が到達可能か判定
連結成分の数もカウント可能.
有向グラフでも適用可能.その場合は,弱連結成分判定
code:python
n, m = list(map(int, input().split())) # 点数,枝数
uf = UnionFind(n)
for _ in range(m):
s, t = list(map(int, input().split()))
uf.union(s, t) # 枝の両端点を同じグループにする
q = int(input())
for _ in range(q):
s, t = list(map(int, input().split()))
print('yes' if uf.same(s, t) else 'no') # sとtが到達可能なら yes,そうでないなら no を出力
(2) グラフの最小全域木
無向グラフにおける最小全域木を求める.クラスカル法を利用.
枝のソートを逆にすれば,最大全域木も可能
code:python
from itertools import combinations
n = int(input()) # 点数
edges = [] # (距離,点1,点2) を要素とするリスト
for i, j in combinations(range(n), 2):
edges.append((aij, i, j)) edges.sort() # 最小全域木を求めるので,距離について小さい順にソート
# print(edges)
uf = UnionFind(n)
ans = 0
for d, i, j in edges:
if not uf.same(i, j):
uf.union(i, j) # 閉路ができなければ追加する
ans += d # 全域木の枝数は n-1なので,n-1本枝が追加されたらbreakしてもよい
print(ans)
3. 拡張版 Union Find
(1) 有向枝重み付きUnion Find
有向グラフの枝に重みがある場合に使用
有向枝$ (u, v)に対して,$ u\rightarrow vの重みが$ wなら,$ v\rightarrow uの重みは$ -w
点$ aから点$ bへ到着する最大の重みを出力.無限大のときはinf.到着不可能のときはnan
code: python
import math
from collections import defaultdict
import sys
sys.setrecursionlimit(10**7)
class WeightedUnionFind():
""" 重み付きUnionFind (0-index) """
def __init__(self, n: int):
"""
重み付きUnionFindの初期化
Parameters
----------
n : int
頂点数
"""
self.n = n
self.parents = -1 * self.n # 自身の親となる頂点番号 (自身が根の時は要素数) self.weight = 0 * self.n # 頂点の重み self.f_inf = float("inf")
self.f_nan = float("nan")
def find(self, x: int) -> int:
"""
頂点xの根となる頂点番号を取得する
Parameters
----------
x : int
頂点番号
"""
return x
else:
current_parent = self.parentsx self.parentsx = self.find(self.parentsx) def union(self, x: int, y: int, w: int) -> None:
"""
頂点x -> 頂点y のコストがwという状態を保持したまま、
頂点xが属する集合と頂点yが属する集合を連結する
Parameters
----------
x : int
始点となる頂点番号
y : int
終点となる頂点番号
w : int
辺の重み
Notes
-----
指定した頂点同士を連結した際、閉路全体の重みの差が0にならないなら重みはinfになる
"""
w -= self._func_weight(x)
w += self._func_weight(y)
x = self.find(x)
y = self.find(y)
if x == y:
if w:
return
else:
return
if self.parentsx > self.parentsy: x, y = y, x
else:
w = -w
self.parentsx += self.parentsy self.weightx = self.f_inf else:
def _func_weight(self, x: int) -> int:
""" 内部の処理で用いているだけなので使用しない """
def diff(self, x: int, y: int) -> int:
"""
頂点x -> 頂点y にかかるコストを取得する
Parameters
----------
x : int
始点となる頂点番号
y : int
終点となる頂点番号
Return
-----
連結している頂点同士を指定していた場合
頂点同士の距離を返す
(閉路が出来ている場合は閉路内の重みの差が0でないならinfを返す)
連結していない頂点同士を指定した場合
到達不可能なのでnanを返す
"""
if self.same(x, y):
cost = self._func_weight(x) - self._func_weight(y)
if math.isnan(cost):
return self.f_inf
return cost
else:
return self.f_nan
# return self._func_weight(x) - self._func_weight(y) if self.same(x, y) else self.f_nan
def size(self, x: int) -> int:
"""
頂点xが属する集合の要素数を取得する
Parameters
----------
x : int
頂点番号
"""
def same(self, x: int, y: int) -> bool:
"""
頂点xと頂点yが同じ集合に属しているか判定する
Parameters
----------
x : int
頂点番号
y : int
頂点番号
"""
return self.find(x) == self.find(y)
def members(self, x: int) -> list:
"""
頂点xと同じ集合に属する要素を全て取得する
Parameters
----------
x : int
頂点番号
"""
root = self.find(x)
def roots(self) -> list:
"""
根である要素を全て取得する
"""
def group_count(self) -> int:
"""
全ての集合の数を取得する
"""
return len(self.roots())
def all_group_members(self) -> defaultdict(list):
"""
全ての集合とそれに属する要素を取得する
Returns
-------
all_group_members : defaultdict(list)
key : 根要素の頂点番号, value : keyを根とする全ての頂点番号を含むリスト
"""
group_members = defaultdict(list)
for member in range(self.n):
return group_members
def debug_print(self) -> None:
"""
デバッグ用に現在の状態をprint出力する
Notes
-----
"""
print('\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items()))
code:python
### 使用例 ###
wuf = WeightedUnionFind(6) # 0-index. 点0~点5まで
wuf.union(0, 2, 4) # 点0 -> 点2 の移動にかかるコストが4となるように点0が属するグループと、点2が属するグループを結合
wuf.debug_print() # 0:0, 2, 1:1, 3:3, 4:4, 5:5 print(wuf.weight)
print(wuf.diff(0, 2)) # 4 <= 点1から点3へ移動したときのコスト
print("\n")
wuf.union(1, 2, 7) # 点1 -> 点2 の移動にかかるコストが7となるように点1が属するグループと、点2が属するグループを結合
print(wuf.diff(0, 1)) # -3 <= 点0 -> 点2 -> 点1 の移動となりコストは-3
print(wuf.diff(4, 5)) # nan <= 連結していない頂点は到達不可なのでnanが出力される
print("\n")
wuf.union(0, 1, 1) # 既に連結している同士の頂点を指定しているので閉路が出来る
print(wuf.diff(0, 1)) # inf <= 閉路内の重みの差が0で無いのでinfが出力される
print("\n")
wuf.union(3, 4, 1) # 点3 -> 点4 の移動にかかるコストが1となるように点3が属するグループと、点4が属するグループを結合
wuf.union(4, 3, -1) # 点4 -> 点3 の移動にかかるコストが-1となるように点4が属するグループと、点3が属するグループを結合
print(wuf.diff(3, 4)) # 1 <= 点3から点4へ移動したときのコスト
(2) 箱型 Union Find
箱に要素を入れる操作addがある. 要素が最初から全て箱に入っておらず,途中で箱に追加できる
Union Find操作の対象は箱.ただし,union(x, y)は併合ではなく,箱yの中の全要素を箱xへ移動
findで箱の中の要素のうち,根となる要素を返す.
find_boxで箱の番号を返す.根となる要素の親として,箱番号×(-1)が保存されている.
code:python
from collections import defaultdict
class BoxUnionFind():
def __init__(self, n):
self.parents = defaultdict(int) # 各要素の親の番号
def add(self, x, i): # 要素 xを箱 iに入れる
if self.rootsi is None: # 箱iに要素が入っていないとき,要素xを入れて根とする self.parentsx = -i # 要素xは根であり,箱iに入る else: # 箱xに別の要素が入っているとき
self.parentsx = self.rootsi def find(self, x): # 要素 xの入っている箱の根となる要素番号を返す
if self.parentsx < 0: # 根の親は 負数 return x
else:
self.parentsx = self.find(self.parentsx) # 経路圧縮 def find_box(self, x): # 要素 xの入っている箱の番号を返す
def union(self, x, y): # 箱Yを箱Xに移動(unionという名前だけど,併合ではない)
x_root = self.rootsx # 箱xの根となる要素番号(要素がないときは None) y_root = self.rootsy # 箱yの根となる要素番号(要素がないときは None) if y_root is None: # 箱Yが空なら何もしない
return
if x_root is None: # 箱Xだけ空なら,Yをそのまま移す
elif x_root != y_root: # 両方とも球が入っていて,根が異なるとき
# 使用方法
buf = BoxUnionFind(5) # 箱1~5を作成(1-indexed)
for i in range(1, 5+1):
buf.add(i, i) # 箱1に球1,箱2に球2,...,箱5に球5を入れる
print(buf.find_box(5)) # 球5の入る箱 --> 5
buf.union(1, 4) # 箱4の中身を箱1へ移動.箱1={1,4}, 箱4={}
buf.add(6, 1) # 箱1に球6を追加
buf.add(7, 4) # 箱4に球7を追加.箱1={1,4,6}, 箱4={7}
print(buf.find_box(7)) # 球7の入る箱 --> 4
buf.union(3, 1) # 箱1の中身を箱3へ移動.箱1={}, 箱3={3,1,4,6}
print(buf.find_box(4)) # 球4の入る箱 --> 3
buf.union(1, 4) # 箱4の中身を箱1へ移動.箱4={}, 箱1={7}
print(buf.find_box(7)) # 球7の入る箱 --> 1
print(buf.find_box(6)) # 球6の入る箱 --> 3